{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4fe2bca5-2d2a-400b-9815-49c10f3f9305",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow.\n",
      "WARNING:root:For using TFDSDataLoader please install tensorflow_datasets.\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "from typing import Dict, Optional\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax_md\n",
    "import numpy as np\n",
    "from ase.io import read\n",
    "from jax_md import units\n",
    "\n",
    "from so3lr import to_jax_md\n",
    "from so3lr import So3lrPotential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c7745db-7b82-4392-9bc7-463022417027",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable 64-bit floating point support in JAX. This is required for simulations in float64\n",
    "# and has to be called before any JAX computation is performed.\n",
    "jax.config.update('jax_enable_x64', True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "81fd530f-c7d7-418f-a2aa-519b5fad6136",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some helper functions used throughout the example notebook.\n",
    "\n",
    "# Default nose hoover chain parameters.\n",
    "def default_nhc_kwargs(\n",
    "    tau: jnp.float32, \n",
    "    overrides: Dict\n",
    ") -> Dict:\n",
    "    \n",
    "    default_kwargs = {\n",
    "        'chain_length': 3, \n",
    "        'chain_steps': 2, \n",
    "        'sy_steps': 3,\n",
    "        'tau': tau\n",
    "    }\n",
    "    \n",
    "    if overrides is None:\n",
    "        return default_kwargs\n",
    "  \n",
    "    return {\n",
    "      k: overrides.get(k, default_kwargs[k]) for k in default_kwargs\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75141afa-35b3-42c1-bbd7-b46447c2fa34",
   "metadata": {},
   "source": [
    "# Read the Molecule\n",
    "\n",
    "Start by reading the molecular structure. Here we assume a file format that ASE can read, e.g. an `XYZ` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e5b49475-023c-49c5-bfd3-787efa9c79d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the last frame of the DHA structure shipped in the test directory.\n",
    "path_to_xyz = '../tests/test_data/dha.xyz'\n",
    "atoms = read(path_to_xyz, index=-1)\n",
    "\n",
    "# Convert positions, atomic numbers and masses from ASE to JAX arrays.\n",
    "positions_init, species, masses = (\n",
    "    jnp.array(values)\n",
    "    for values in (\n",
    "        atoms.get_positions(),\n",
    "        atoms.get_atomic_numbers(),\n",
    "        atoms.get_masses()\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6ab6858-39a9-4e07-84d1-5f84df04f9b7",
   "metadata": {},
   "source": [
    "# Prepare JAX-MD Energy Function\n",
    "\n",
    "JAX-MD is built around the energy function and the neighbor list. Since we have short- and long-range neighbors, we need two neighbor lists instead of the usual single one. We use the convenience interface `to_jax_md` provided within this repository, which returns the two neighbor list functions as well as the SO3LR energy function. It takes as input the `So3lrPotential` itself, as well as a displacement function. The `displacement` function calculates the displacement vectors between atoms and can differ, e.g. depending on whether simulations are performed under periodic boundary conditions (PBCs) or not. Also, positions can be represented either in real or in fractional coordinates. For more details on displacement functions as well as on fractional coordinates we refer to the JAX-MD docs. Here we simulate DHA in vacuum (no PBCs), which allows us to use `jax_md.space.free` to obtain a displacement function. It also returns a `shift` function which determines how to update the atomic positions during simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "41943830-f8e4-41a4-bf49-8c1826bdbc1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:CheckpointMetadata file does not exist: /Users/thorbenfrank/Documents/git/so3lr/so3lr/params/checkpoints/ckpt_5150000/_CHECKPOINT_METADATA\n",
      "WARNING:absl:`StandardCheckpointHandler` expects a target tree to be provided for restore. Not doing so is generally UNSAFE unless you know the present topology to be the same one as the checkpoint was saved under.\n",
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:5191: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.\n",
      "  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)\n",
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n",
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/ops/scatter.py:92: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Simulation in free space: there is no simulation box.\n",
    "box = None\n",
    "displacement, shift = jax_md.space.free()\n",
    "\n",
    "# Build the short-range neighbor list, the long-range neighbor list and the energy function.\n",
    "neighbor_fn, neighbor_fn_lr, energy_fn = to_jax_md(\n",
    "    potential=So3lrPotential(dtype=jnp.float64),  # or float32 for single precision\n",
    "    displacement_or_metric=displacement,\n",
    "    box_size=box,\n",
    "    species=species,\n",
    "    capacity_multiplier=1.25,\n",
    "    buffer_size_multiplier_sr=1.25,\n",
    "    buffer_size_multiplier_lr=1.25,\n",
    "    minimum_cell_size_multiplier_sr=1.0,\n",
    "    # Cell list partitioning can only be applied if there is a simulation box.\n",
    "    disable_cell_list=True,\n",
    "    fractional_coordinates=False\n",
    ")\n",
    "\n",
    "# Compile the energy function and derive the (compiled) force function from it.\n",
    "energy_fn = jax.jit(energy_fn)\n",
    "force_fn = jax.jit(jax_md.quantity.force(energy_fn))\n",
    "\n",
    "# Allocate the short- and the long-range neighbor list for the initial positions.\n",
    "nbrs, nbrs_lr = (\n",
    "    fn.allocate(positions_init, box=box) for fn in (neighbor_fn, neighbor_fn_lr)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ecb4081e-9594-4ec2-b9d6-4637d6ce22f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(-94.89951356, dtype=float64)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: evaluate the energy function on the initial positions using the\n",
    "# indices of the freshly allocated short- and long-range neighbor lists.\n",
    "energy_fn(\n",
    "    positions_init, \n",
    "    neighbor=nbrs.idx, \n",
    "    neighbor_lr=nbrs_lr.idx, \n",
    "    box=box\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f58f07d1-5ee4-480e-834b-a5fd3950f0e1",
   "metadata": {},
   "source": [
    "# Structure Relaxation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "90bbd18b-3498-48bd-a288-82c52c56b073",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step\tE\tFmax\n",
      "----------------------------------------\n",
      "0\t-95.38\t0.51\n",
      "1\t-95.57\t0.23\n",
      "2\t-95.69\t0.15\n",
      "3\t-95.77\t0.12\n",
      "4\t-95.81\t0.07\n"
     ]
    }
   ],
   "source": [
    "min_cycles = 5  # Number of relaxation cycles; energy and maximal force are printed after each cycle.\n",
    "min_steps = 10  # Number of FIRE steps per cycle.\n",
    "\n",
    "fire_init, fire_apply = jax_md.minimize.fire_descent(\n",
    "    energy_fn,\n",
    "    shift,\n",
    "    dt_start=0.05,\n",
    "    dt_max=0.1,\n",
    "    n_min=2\n",
    ")\n",
    "\n",
    "fire_apply = jax.jit(fire_apply)\n",
    "fire_state = fire_init(\n",
    "    positions_init,\n",
    "    box=box,\n",
    "    neighbor=nbrs.idx,\n",
    "    neighbor_lr=nbrs_lr.idx\n",
    ")\n",
    "\n",
    "@jax.jit\n",
    "def step_fire_fn(i, fire_state):\n",
    "    \"\"\"Loop body for `jax.lax.fori_loop`: one FIRE step followed by a neighbor list update.\n",
    "\n",
    "    Args:\n",
    "        i: Loop index (unused).\n",
    "        fire_state: Carry tuple `(fire_state, nbrs, nbrs_lr)`.\n",
    "\n",
    "    Returns:\n",
    "        The updated carry tuple `(fire_state, nbrs, nbrs_lr)`.\n",
    "    \"\"\"\n",
    "    fire_state, nbrs, nbrs_lr = fire_state\n",
    "\n",
    "    fire_state = fire_apply(\n",
    "        fire_state,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    # Update both neighbor lists for the new positions.\n",
    "    nbrs = nbrs.update(\n",
    "        fire_state.position,\n",
    "        neighbor=nbrs.idx\n",
    "    )\n",
    "\n",
    "    nbrs_lr = nbrs_lr.update(\n",
    "        fire_state.position,\n",
    "        neighbor_lr=nbrs_lr.idx\n",
    "    )\n",
    "\n",
    "    return fire_state, nbrs, nbrs_lr\n",
    "\n",
    "print('Step\\tE\\tFmax')\n",
    "print('----------------------------------------')\n",
    "for i in range(min_cycles):\n",
    "    new_fire_state, new_nbrs, new_nbrs_lr = jax.lax.fori_loop(\n",
    "        0,\n",
    "        min_steps,\n",
    "        step_fire_fn,\n",
    "        (fire_state, nbrs, nbrs_lr)\n",
    "    )\n",
    "\n",
    "    # A neighbor list that overflowed misses interactions, so the result of this cycle\n",
    "    # cannot be trusted. Discard it and reallocate from the last valid positions.\n",
    "    sr_overflow = bool(new_nbrs.did_buffer_overflow)\n",
    "    lr_overflow = bool(new_nbrs_lr.did_buffer_overflow)\n",
    "    if sr_overflow or lr_overflow:\n",
    "        if sr_overflow:\n",
    "            print('Neighbor list overflowed, reallocating.')\n",
    "            nbrs = neighbor_fn.allocate(fire_state.position, box=box)\n",
    "        if lr_overflow:\n",
    "            print('Long-range neighbor list overflowed, reallocating.')\n",
    "            nbrs_lr = neighbor_fn_lr.allocate(fire_state.position, box=box)\n",
    "        continue\n",
    "\n",
    "    fire_state, nbrs, nbrs_lr = new_fire_state, new_nbrs, new_nbrs_lr\n",
    "\n",
    "    E = energy_fn(\n",
    "        fire_state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    F = force_fn(\n",
    "        fire_state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    print('{}\\t{:.2f}\\t{:.2f}'.format(i, E, np.abs(F).max()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b6eba7ba-0ab5-439c-a184-dab466225c3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[-0.07376099,  0.0479126 ,  0.06988525],\n",
       "       [ 0.0149231 , -0.03244019,  0.03833008],\n",
       "       [-0.01501465, -0.00018311, -0.04849243]], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the change caused by the relaxation: initial minus optimized positions for the first three atoms.\n",
    "(positions_init - fire_state.position)[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32ce6358-7245-4e00-a95b-c673a1b286d7",
   "metadata": {},
   "source": [
    "# Nose-Hoover Chain NVT Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4b390206-c34a-4bee-aaa8-057f0f8b8b4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:183: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  return asarray(x, dtype=self.dtype)\n"
     ]
    }
   ],
   "source": [
    "# Simulation parameters\n",
    "\n",
    "timestep = 0.0005  # Time step in ps\n",
    "nvt_cycles = 5  # Number of NVT cycles.\n",
    "nvt_steps = 50  # Number of NVT steps per cycle. The total number of MD steps equals nvt_cycles * nvt_steps.\n",
    "\n",
    "T_nvt = 300  # Target temperature.\n",
    "T_init = None  # Initial temperature. If None, the target temperature is used.\n",
    "\n",
    "thermo = 100  # Thermo value in the Nose-Hoover chain.\n",
    "chain = 3  # Number of chains in the Nose-Hoover chain.\n",
    "chain_steps = 2  # Number of steps per chain.\n",
    "sy_steps = 3\n",
    "\n",
    "# Fall back to the target temperature if no initial temperature is given.\n",
    "T_init = T_nvt if T_init is None else T_init\n",
    "\n",
    "# Dictionary with the NHC settings.\n",
    "new_nhc_kwargs = dict(\n",
    "    chain_length=chain,\n",
    "    chain_steps=chain_steps,\n",
    "    sy_steps=sy_steps\n",
    ")\n",
    "\n",
    "# Convert to the metal unit system. Only `timestep` and `T_init` are rescaled here.\n",
    "unit = units.metal_unit_system()\n",
    "timestep *= unit['time']\n",
    "T_init *= unit['temperature']\n",
    "\n",
    "rng_key = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ac16efdc-76dd-4ee2-bcff-1bc48407776f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/thorbenfrank/Documents/venvs/so3lr/lib/python3.12/site-packages/jax/_src/numpy/reductions.py:213: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  return _reduction(a, \"sum\", np.sum, lax.add, 0, preproc=_cast_to_numeric,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step\tKE\tPE\tTotal Energy\tTemperature\tH\ttime/steps\n",
      "-----------------------------------------------------------------------------------\n",
      "0\t0.30\t-95.61\t-95.308\t41.6\t-95.282\t0.2221\n",
      "50\t0.38\t-95.60\t-95.219\t52.9\t-95.180\t0.0186\n",
      "100\t0.47\t-95.57\t-95.105\t64.8\t-95.128\t0.0183\n",
      "150\t0.57\t-95.51\t-94.950\t78.1\t-95.106\t0.0187\n",
      "200\t0.74\t-95.48\t-94.744\t102.3\t-95.082\t0.0179\n",
      "Total_time:  15.108027219772339\n"
     ]
    }
   ],
   "source": [
    "# Choose the Nose-Hoover thermostat.\n",
    "# The chain settings are unpacked into keyword arguments: `nvt_nose_hoover` takes\n",
    "# `chain_length`, `chain_steps`, `sy_steps` and `tau` directly. Passing them bundled as\n",
    "# `thermostat_kwargs` would leave the thermostat at its default settings.\n",
    "init_fn, apply_fn = jax_md.simulate.nvt_nose_hoover(\n",
    "    energy_fn,\n",
    "    shift,\n",
    "    dt=timestep,\n",
    "    kT=T_init,\n",
    "    box=box,\n",
    "    **default_nhc_kwargs(thermo * timestep, new_nhc_kwargs)\n",
    ")\n",
    "\n",
    "apply_fn = jax.jit(apply_fn)\n",
    "init_fn = jax.jit(init_fn)\n",
    "\n",
    "# Target temperature converted exactly like `T_init`, so that every `kT` argument\n",
    "# below is expressed in the same units.\n",
    "kT_nvt = T_nvt * unit['temperature']\n",
    "\n",
    "# Initialize the state using the positions and neighbors from the structure relaxation.\n",
    "state = init_fn(\n",
    "    rng_key,\n",
    "    fire_state.position,\n",
    "    box=box,\n",
    "    neighbor=nbrs.idx,\n",
    "    neighbor_lr=nbrs_lr.idx,\n",
    "    kT=T_init,\n",
    "    mass=masses\n",
    ")\n",
    "\n",
    "@jax.jit\n",
    "def step_nvt_fn(i, state):\n",
    "    \"\"\"Loop body for `jax.lax.fori_loop`: one NVT step followed by a neighbor list update.\n",
    "\n",
    "    Args:\n",
    "        i: Loop index (unused).\n",
    "        state: Carry tuple `(state, nbrs, nbrs_lr, box, temp_i)` where `temp_i` is the\n",
    "            value passed as `kT` to the thermostat.\n",
    "\n",
    "    Returns:\n",
    "        The updated carry tuple `(state, nbrs, nbrs_lr, box, temp_i)`.\n",
    "    \"\"\"\n",
    "    state, nbrs, nbrs_lr, box, temp_i = state\n",
    "\n",
    "    state = apply_fn(\n",
    "        state,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        kT=temp_i,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    nbrs = nbrs.update(\n",
    "        state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    nbrs_lr = nbrs_lr.update(\n",
    "        state.position,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    return state, nbrs, nbrs_lr, box, temp_i\n",
    "\n",
    "# Track total time and step times averaged over cycle.\n",
    "total_time = time.perf_counter()\n",
    "\n",
    "positions_md = []\n",
    "steps_done = 0  # Number of MD steps that have been accepted so far.\n",
    "\n",
    "print('Step\\tKE\\tPE\\tTotal Energy\\tTemperature\\tH\\ttime/steps')\n",
    "print('-----------------------------------------------------------------------------------')\n",
    "for i in range(nvt_cycles):\n",
    "\n",
    "    old_time = time.perf_counter()\n",
    "\n",
    "    # Do `nvt_steps` NVT steps. The results are only committed if no neighbor list overflowed.\n",
    "    new_state, new_nbrs, new_nbrs_lr, new_box, _ = jax.block_until_ready(\n",
    "        jax.lax.fori_loop(\n",
    "            0,\n",
    "            nvt_steps,\n",
    "            step_nvt_fn,\n",
    "            (state, nbrs, nbrs_lr, box, kT_nvt)  # carry state is tuple\n",
    "        )\n",
    "    )\n",
    "\n",
    "    new_time = time.perf_counter()\n",
    "\n",
    "    # Check for overflow of both sr and lr neighbors. On overflow the cycle is discarded\n",
    "    # and the affected list is reallocated from the last valid positions.\n",
    "    sr_overflow = bool(new_nbrs.did_buffer_overflow)\n",
    "    lr_overflow = bool(new_nbrs_lr.did_buffer_overflow)\n",
    "    if sr_overflow or lr_overflow:\n",
    "        if sr_overflow:\n",
    "            print('Neighbor list overflowed, reallocating.')\n",
    "            nbrs = neighbor_fn.allocate(state.position, box=box)\n",
    "        if lr_overflow:\n",
    "            print('Long-range neighbor list overflowed, reallocating.')\n",
    "            nbrs_lr = neighbor_fn_lr.allocate(state.position, box=box)\n",
    "        continue\n",
    "\n",
    "    state, nbrs, nbrs_lr, box = new_state, new_nbrs, new_nbrs_lr, new_box\n",
    "    steps_done += nvt_steps\n",
    "\n",
    "    # Calculate some quantities for printing\n",
    "    KE = jax_md.quantity.kinetic_energy(\n",
    "        momentum=state.momentum,\n",
    "        mass=state.mass\n",
    "    )\n",
    "\n",
    "    PE = energy_fn(\n",
    "        state.position,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    T = jax_md.quantity.temperature(\n",
    "        momentum=state.momentum,\n",
    "        mass=state.mass\n",
    "    ) / unit['temperature']\n",
    "\n",
    "    H = jax_md.simulate.nvt_nose_hoover_invariant(\n",
    "        energy_fn,\n",
    "        state,\n",
    "        kT=kT_nvt,\n",
    "        neighbor=nbrs.idx,\n",
    "        neighbor_lr=nbrs_lr.idx,\n",
    "        box=box\n",
    "    )\n",
    "\n",
    "    positions_md.append(np.array(state.position))\n",
    "\n",
    "    print(f'{steps_done}\\t{KE:.2f}\\t{PE:.2f}\\t{KE+PE:.3f}\\t{T:.1f}\\t{H:.3f}\\t{(new_time - old_time) / nvt_steps:.4f}')\n",
    "\n",
    "print('Total_time: ', time.perf_counter() - total_time)\n",
    "\n",
    "# Clear all caches\n",
    "jax.clear_caches()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23e10ea7-e3b5-4ea4-a680-d62e4ee6a840",
   "metadata": {},
   "source": [
    "# Visualization and IO\n",
    "\n",
    "For now, we appended the positions to a simple Python list `positions_md`. This allows us to directly visualize the trajectory and write the frames to `xyz` after the simulation. However, for production runs it might be necessary to save the frames (and potentially other information) along the way. For large structures, using `ase.io.write` can be very slow and can slow down the simulation by several orders of magnitude. To learn how to efficiently write frames and other statistics during simulation, check out the `production_run.ipynb` example notebook. It uses `.hdf5` files to efficiently perform save operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1748cb-e94e-4794-8b89-87cd6c118e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualization inside the notebook requires nglview, install it via `pip install nglview` in your virtualenv.\n",
    "import nglview as nv\n",
    "from ase import Atoms\n",
    "\n",
    "# One ASE `Atoms` object per stored frame.\n",
    "atoms_traj = [\n",
    "    Atoms(numbers=np.array(species), positions=frame_positions)\n",
    "    for frame_positions in positions_md\n",
    "]\n",
    "\n",
    "nv.show_asetraj(atoms_traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2af02131-6f64-44a0-9acd-3ec5ad20815e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the frames to xyz.\n",
    "from ase.io import write\n",
    "\n",
    "# Write all frames with a single call. The file is overwritten, so re-running this cell\n",
    "# does not append the same trajectory to an already existing file a second time.\n",
    "write(\n",
    "    'nvt_md_trajectory.xyz',\n",
    "    atoms_traj\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "so3lr",
   "language": "python",
   "name": "so3lr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
